
Conversation

@justinchuby justinchuby commented Jun 20, 2025

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
@justinchuby justinchuby changed the title from "Support gpa in aten spda" to "Support gqa in aten spda" Jun 20, 2025

codecov bot commented Jun 20, 2025

Codecov Report

❌ Patch coverage is 1.85185% with 53 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.24%. Comparing base (03ab4c5) to head (16d75e9).
⚠️ Report is 94 commits behind head on main.

Files with missing lines Patch % Lines
onnxscript/function_libs/torch_lib/ops/nn.py 1.85% 48 Missing and 5 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2408      +/-   ##
==========================================
- Coverage   70.38%   70.24%   -0.14%     
==========================================
  Files         199      199              
  Lines       25223    25270      +47     
  Branches     2686     2693       +7     
==========================================
- Hits        17753    17751       -2     
- Misses       6541     6586      +45     
- Partials      929      933       +4     


Comment on lines +1994 to +1996
return _aten_scaled_dot_product_attention_bool_mask_onnx(
query, key, value, attn_mask, scale, dropout_p, enable_gqa=enable_gqa
)

Check failure

Code scanning / CodeQL

Wrong name for an argument in a call

Keyword argument 'enable_gqa' is not a supported parameter name of [function _aten_scaled_dot_product_attention_bool_mask_onnx](1).

Copilot Autofix

AI 3 months ago

To fix the issue, the keyword argument enable_gqa should be removed from the call to _aten_scaled_dot_product_attention_bool_mask_onnx on line 1994. This ensures that the function is called with only the parameters it supports. The removal of enable_gqa will not affect the functionality of _aten_scaled_dot_product_attention_bool_mask_onnx, as it does not use this argument.
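The failure mode CodeQL is flagging can be shown with a minimal, self-contained sketch. The function below is a hypothetical stand-in for `_aten_scaled_dot_product_attention_bool_mask_onnx`, reduced to its signature; like the real helper, it accepts no `enable_gqa` parameter:

```python
# Hypothetical stand-in: same parameter list shape as the torchlib helper,
# with no enable_gqa parameter.
def _attention_bool_mask(query, key, value, attn_mask, scale, dropout_p):
    return scale

# Passing an unsupported keyword fails at call time with a TypeError,
# which is what CodeQL reports statically before the code ever runs.
err = None
try:
    _attention_bool_mask(1, 2, 3, None, 1.0, 0.0, enable_gqa=True)
except TypeError as exc:
    err = str(exc)
print(err)

# Dropping the keyword, as the autofix suggests, makes the call valid.
ok = _attention_bool_mask(1, 2, 3, None, 1.0, 0.0)
```

(The alternative fix, of course, is to add `enable_gqa` to the helper's signature and honor it, which is what this PR ultimately needs.)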

Suggested changeset 1
onnxscript/function_libs/torch_lib/ops/nn.py

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1994,3 +1994,3 @@
         return _aten_scaled_dot_product_attention_bool_mask_onnx(
-            query, key, value, attn_mask, scale, dropout_p, enable_gqa=enable_gqa
+            query, key, value, attn_mask, scale, dropout_p
         )
EOF
Unable to commit as this autofix suggestion is now outdated
@justinchuby justinchuby marked this pull request as draft June 20, 2025 19:14
axis=0
)
value_unsqueezed = op.Unsqueeze(value, [-2])
value_tiled = op.Tile(value_unsqueezed, op.Concat(
Contributor

op.Tile does not align with the PyTorch implementation.

    if (
        (q_num_heads != k_num_heads)
        and (q_num_heads % k_num_heads == 0)
        and (k_num_heads == v_num_heads)
    ):
        seq_reps = q_num_heads // k_num_heads
        # Interleave-repeat each KV head: [h0, h0, h1, h1, ...]
        K = np.repeat(K, repeats=seq_reps, axis=1)
        V = np.repeat(V, repeats=seq_reps, axis=1)

We should be able to reuse repeat_interleave here when it's done.

Collaborator Author

Should we use Expand for repeat_interleave, for simplicity over Tile?
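For what it's worth, the interleave-repeat can be built from an expand-style broadcast instead of Tile: unsqueeze a new axis after the head axis, broadcast it to the repeat count, then merge the two axes. A small numpy sketch of the pattern (illustrative only, not the torchlib code; shapes and the helper name are assumptions):

```python
import numpy as np

def repeat_interleave_heads(x, reps):
    """Interleave-repeat axis 1 (the head axis): [h0, h0, h1, h1, ...].

    Equivalent to np.repeat(x, reps, axis=1), but built from the
    unsqueeze -> broadcast -> reshape pattern, which maps directly onto
    ONNX Unsqueeze / Expand / Reshape with no Tile.
    """
    b, h, s, d = x.shape
    # Insert a length-1 axis after the heads, broadcast it to `reps`,
    # then fold it into the head axis.
    expanded = np.broadcast_to(x[:, :, None, :, :], (b, h, reps, s, d))
    return expanded.reshape(b, h * reps, s, d)

K = np.random.rand(2, 4, 5, 8)  # batch, k_num_heads, seq, head_dim
out = repeat_interleave_heads(K, 3)
assert out.shape == (2, 12, 5, 8)
# Matches the np.repeat formulation quoted above.
assert np.array_equal(out, np.repeat(K, repeats=3, axis=1))
```

The key point is that the expand-based form keeps the repeated copies adjacent per source head, matching np.repeat's interleaved layout, whereas a naive Tile on the head axis would produce the block-repeated layout [h0, h1, ..., h0, h1, ...].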

Contributor

Yeah, https://github.com/onnx/onnx/blob/62f2facfc29b0b6a26247614d56e7c294a6206fc/onnx/defs/nn/defs.cc#L3840-L3856

I wonder if we can just adapt whatever function body is in defs.cc to torchlib? Is there any difference?

Collaborator Author

Probably not. I must have been using the old implementation.

Successfully merging this pull request may close these issues.

[ONNX] Support for grouped query attention
2 participants